load("//tools/install:install.bzl", "install")
load(
    "//tools/skylark:pybind.bzl",
    "get_pybind_package_info",
)
load(
    "//tools/skylark:drake_py.bzl",
    "drake_py_binary",
    "drake_py_library",
    "drake_py_test",
)
load(
    "@drake_detected_os//:os.bzl",
    "UBUNTU_RELEASE",
)
load("//bindings/pydrake:pydrake.bzl", "add_lint_tests_pydrake")

# Default visibility: targets in this package may be referenced only from
# //bindings/pydrake and the packages beneath it.
package(default_visibility = ["//bindings/pydrake:__subpackages__"])

# This determines how `PYTHONPATH` is configured, and how to install the
# bindings. Its `py_imports` field feeds the `gym` roll-up library and its
# `py_dest` field feeds the `install` rule, both declared later in this file.
PACKAGE_INFO = get_pybind_package_info("//bindings")

# The package's `__init__.py`. The other libraries in this file all depend on
# this target; it in turn depends on the parent `examples` package.
drake_py_library(
    name = "module_py",
    srcs = ["__init__.py"],
    deps = ["//bindings/pydrake/examples"],
)

# Helper module `named_view_helpers.py`; it needs the pydrake multibody
# bindings in addition to this package's `__init__.py`.
drake_py_library(
    name = "named_view_helpers_py",
    srcs = ["named_view_helpers.py"],
    deps = [
        ":module_py",
        "//bindings/pydrake/multibody",
    ],
)

# Underscore-prefixed helper module `_bazel_cwd_helpers.py`. Within this file
# it is consumed only by `:cart_pole_binaries`.
drake_py_library(
    name = "bazel_cwd_helpers_py",
    srcs = ["_bazel_cwd_helpers.py"],
    deps = [":module_py"],
)

# The cart-pole environment module (`envs/cart_pole.py`), bundled with the
# SDF model file that it needs as runtime data.
drake_py_library(
    name = "cart_pole_py",
    srcs = ["envs/cart_pole.py"],
    data = ["models/cartpole_BSA.sdf"],
    deps = [
        ":module_py",
        ":named_view_helpers_py",
        "//bindings/pydrake/geometry",
        "//bindings/pydrake/gym",
        "//bindings/pydrake/multibody",
        "//bindings/pydrake/systems",
    ],
)

# Library form of the two cart-pole scripts. Packaging them as a library lets
# them appear in `PY_LIBRARIES` (used by the roll-up and the install rule);
# the runnable `train_cart_pole` / `play_cart_pole` targets and their tests
# all depend on this target.
drake_py_library(
    name = "cart_pole_binaries",
    srcs = [
        "play_cart_pole.py",
        "train_cart_pole.py",
    ],
    deps = [
        ":bazel_cwd_helpers_py",
        ":cart_pole_py",
        ":module_py",
        "//bindings/pydrake/geometry",
    ],
)

# Runnable entry point for the training script. Its dependencies come from
# `:cart_pole_binaries`.
drake_py_binary(
    name = "train_cart_pole",
    srcs = ["train_cart_pole.py"],
    deps = [":cart_pole_binaries"],
)

# Runnable entry point for the playback script. Its dependencies come from
# `:cart_pole_binaries`.
drake_py_binary(
    name = "play_cart_pole",
    srcs = ["play_cart_pole.py"],
    deps = [":cart_pole_binaries"],
)

# Get some test coverage on the example; `--test` limits its time
# and CPU budget.
drake_py_test(
    name = "train_cart_pole_test",
    srcs = ["train_cart_pole.py"],
    args = ["--test"],
    main = "train_cart_pole.py",
    # On Ubuntu 20.04 the test is tagged `manual`, so it only runs when
    # requested explicitly; elsewhere it carries no tags.
    # NOTE(review): the reason for excluding 20.04 is not recorded here --
    # presumably the RL dependencies below are unsupported there; confirm.
    tags = ["manual"] if UBUNTU_RELEASE == "20.04" else [],
    deps = [
        ":cart_pole_binaries",
        "@gymnasium_py",
        "@stable_baselines3_internal//:stable_baselines3",
    ],
)

# Runs the playback script with `--test` to get some coverage on it.
drake_py_test(
    name = "play_cart_pole_test",
    srcs = ["play_cart_pole.py"],
    args = ["--test"],
    main = "play_cart_pole.py",
    # On Ubuntu 20.04 the test is tagged `manual`, so it only runs when
    # requested explicitly; elsewhere it carries no tags.
    # NOTE(review): the reason for excluding 20.04 is not recorded here --
    # presumably the RL dependencies below are unsupported there; confirm.
    tags = ["manual"] if UBUNTU_RELEASE == "20.04" else [],
    deps = [
        ":cart_pole_binaries",
        "@gymnasium_py",
        "@stable_baselines3_internal//:stable_baselines3",
    ],
)

# Libraries exposed through the `gym` roll-up target and passed to the
# `install` rule.
# NOTE(review): `:module_py` and `:bazel_cwd_helpers_py` are not listed. They
# still reach the roll-up transitively via `deps`, but confirm whether their
# sources (`__init__.py`, `_bazel_cwd_helpers.py`) are meant to be installed.
PY_LIBRARIES = [
    ":cart_pole_binaries",
    ":cart_pole_py",
    ":named_view_helpers_py",
]

# Package roll-up (for Bazel dependencies).
# N.B. `examples` packages do not have `all` modules.
# This target has no sources of its own: it aggregates `PY_LIBRARIES` and
# applies the `PYTHONPATH` configuration computed in `PACKAGE_INFO`.
drake_py_library(
    name = "gym",
    imports = PACKAGE_INFO.py_imports,
    deps = PY_LIBRARIES,
)

# Installs the `PY_LIBRARIES` targets, using the Python destination supplied
# by `PACKAGE_INFO`.
install(
    name = "install",
    targets = PY_LIBRARIES,
    py_dest = PACKAGE_INFO.py_dest,
)

# Registers lint tests for this package via the macro loaded from
# //bindings/pydrake:pydrake.bzl.
add_lint_tests_pydrake()
